# -*- coding: utf-8 -*-
"""Location: ./tests/unit/mcpgateway/test_rpc_tool_invocation.py
Copyright 2025
SPDX-License-Identifier: Apache-2.0
Authors: Mihai Criveti

Test RPC tool invocation after PR #746 changes.
"""

# Standard
from unittest.mock import AsyncMock, MagicMock, patch

# Third-Party
from fastapi.testclient import TestClient
import pytest
from sqlalchemy.orm import Session

# First-Party
from mcpgateway.main import app
from mcpgateway.common.models import Tool
from mcpgateway.services.tool_service import ToolService


@pytest.fixture
def client():
    """Provide a Starlette/FastAPI ``TestClient`` wrapping the gateway ``app``.

    The client is handed out as-is and is not entered as a context manager.
    """
    http_client = TestClient(app)
    return http_client


@pytest.fixture
def mock_db():
    """Provide a ``MagicMock`` whose attribute surface is limited to SQLAlchemy's ``Session``."""
    session_double = MagicMock(spec=Session)
    return session_double


@pytest.fixture
def mock_tool_service():
    """Provide an ``AsyncMock`` restricted to the ``ToolService`` interface.

    Not requested by any test in this file at present.
    """
    return AsyncMock(spec=ToolService)


@pytest.fixture
def sample_tool():
    """Build a ``Tool`` model named ``test_tool`` with a required ``query`` and optional ``limit`` argument.

    Not requested by any test in this file at present.
    """
    schema = {
        "type": "object",
        "properties": {
            "query": {"type": "string"},
            "limit": {"type": "number", "default": 5},
        },
        "required": ["query"],
    }
    return Tool(
        name="test_tool",
        url="http://localhost:8000/test",
        description="A test tool",
        input_schema=schema,
    )


class TestRPCToolInvocation:
    """Exercise the ``/rpc`` JSON-RPC endpoint for tool, resource, prompt, list and lifecycle methods."""

    @pytest.fixture(autouse=True)
    def _rpc_environment(self, mock_db):
        """Disable auth and stub ``mcpgateway.main.get_db`` for the duration of every test in this class.

        NOTE(review): routes that declare ``Depends(get_db)`` likely hold the function object captured
        at import time, so patching the module attribute may not change the session the handler
        receives; FastAPI's ``app.dependency_overrides`` is the documented injection hook — confirm.
        """
        with patch("mcpgateway.config.settings.auth_required", False), patch("mcpgateway.main.get_db", return_value=mock_db):
            yield

    @staticmethod
    def _call_rpc(client, method, params, request_id):
        """POST a JSON-RPC 2.0 envelope to ``/rpc``, check HTTP 200 and the ``jsonrpc`` tag, and return the decoded body."""
        envelope = {"jsonrpc": "2.0", "method": method, "params": params, "id": request_id}
        reply = client.post("/rpc", json=envelope)
        assert reply.status_code == 200
        body = reply.json()
        assert body["jsonrpc"] == "2.0"
        return body

    def test_tools_call_method_new_format(self, client, mock_db):
        """``tools/call`` forwards the tool name and arguments to ``tool_service.invoke_tool``."""
        with patch("mcpgateway.main.tool_service.invoke_tool", new_callable=AsyncMock) as invoke_mock:
            invoke_mock.return_value = {"result": "success", "data": "test data"}

            body = self._call_rpc(client, "tools/call", {"name": "test_tool", "arguments": {"query": "test", "limit": 5}}, 1)

            assert "result" in body
            assert body["id"] == 1

            invoke_mock.assert_called_once()
            forwarded = invoke_mock.call_args.kwargs
            assert forwarded["name"] == "test_tool"
            assert forwarded["arguments"] == {"query": "test", "limit": 5}

    def test_direct_tool_invocation_fails(self, client, mock_db):
        """A bare tool name used as the method (legacy format) is rejected with ``Invalid method``."""
        legacy_params = {"query": "test", "limit": 5}

        body = self._call_rpc(client, "test_tool", legacy_params, 1)

        assert "error" in body
        error = body["error"]
        assert error["code"] == -32000
        assert error["message"] == "Invalid method"
        assert error["data"] == legacy_params
        assert body["id"] == 1

    def test_tools_list_method(self, client, mock_db):
        """``tools/list`` returns the serialized tools produced by ``tool_service.list_tools``."""
        listed_tool = MagicMock()
        listed_tool.model_dump.return_value = {"name": "test_tool", "description": "A test tool"}

        with patch("mcpgateway.main.tool_service.list_tools", new_callable=AsyncMock, return_value=([listed_tool], None)):
            body = self._call_rpc(client, "tools/list", {}, 2)

            assert "result" in body
            assert "tools" in body["result"]
            tools = body["result"]["tools"]
            assert len(tools) == 1
            assert tools[0]["name"] == "test_tool"

    def test_resources_read_method(self, client, mock_db):
        """``resources/read`` wraps the resource returned by ``resource_service.read_resource`` under ``contents``."""
        resource_payload = {"uri": "test://resource", "content": "test content"}

        with patch("mcpgateway.main.resource_service.read_resource", new_callable=AsyncMock, return_value=resource_payload):
            body = self._call_rpc(client, "resources/read", {"uri": "test://resource"}, 3)

            assert "result" in body
            assert "contents" in body["result"]

    def test_prompts_get_method(self, client, mock_db):
        """``prompts/get`` yields a result when ``prompt_service.get_prompt`` succeeds."""
        rendered_prompt = MagicMock()
        rendered_prompt.model_dump.return_value = {"name": "test_prompt", "description": "A test prompt", "messages": []}

        with patch("mcpgateway.main.prompt_service.get_prompt", new_callable=AsyncMock, return_value=rendered_prompt):
            body = self._call_rpc(client, "prompts/get", {"name": "test_prompt", "arguments": {}}, 4)

            assert "result" in body

    def test_initialize_method(self, client, mock_db):
        """``initialize`` echoes the payload produced by ``session_registry.handle_initialize_logic``."""
        init_payload = {"protocolVersion": "1.0", "capabilities": {}, "serverInfo": {"name": "test-server"}}
        init_result = MagicMock(model_dump=MagicMock(return_value=init_payload))
        client_hello = {"protocolVersion": "1.0", "capabilities": {}, "clientInfo": {"name": "test-client"}}

        with patch("mcpgateway.main.session_registry.handle_initialize_logic", new_callable=AsyncMock, return_value=init_result):
            body = self._call_rpc(client, "initialize", client_hello, 5)

            assert "result" in body
            assert body["result"]["protocolVersion"] == "1.0"

    @pytest.mark.parametrize(
        "method,expected_result_key",
        [
            ("tools/list", "tools"),
            ("resources/list", "resources"),
            ("prompts/list", "prompts"),
            ("list_gateways", "gateways"),
            ("list_roots", "roots"),
        ],
    )
    def test_list_methods_return_proper_structure(self, client, mock_db, method, expected_result_key):
        """Every list-style method wraps its (empty) results in a list under the expected key."""
        # Stub every backing service so whichever method is parametrized sees an empty page.
        with patch("mcpgateway.main.tool_service.list_tools", new_callable=AsyncMock, return_value=([], None)), \
             patch("mcpgateway.main.resource_service.list_resources", new_callable=AsyncMock, return_value=([], None)), \
             patch("mcpgateway.main.prompt_service.list_prompts", new_callable=AsyncMock, return_value=([], None)), \
             patch("mcpgateway.main.gateway_service.list_gateways", new_callable=AsyncMock, return_value=[]), \
             patch("mcpgateway.main.root_service.list_roots", new_callable=AsyncMock, return_value=[]):
            body = self._call_rpc(client, method, {}, 100)

            assert "result" in body
            assert expected_result_key in body["result"]
            assert isinstance(body["result"][expected_result_key], list)

    def test_unknown_method_returns_error(self, client, mock_db):
        """An unrecognised method name produces an ``Invalid method`` JSON-RPC error."""
        body = self._call_rpc(client, "unknown/method", {}, 999)

        assert "error" in body
        error = body["error"]
        assert error["code"] == -32000
        assert error["message"] == "Invalid method"
        assert body["id"] == 999


if __name__ == "__main__":
    # pytest.main() returns an exit code rather than exiting; propagate it so that running this
    # file directly reports test failures to the shell/CI instead of always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
